Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc][Polish] gemm optimize by 2d thread tile #56

Merged
merged 1 commit into from
Dec 3, 2024

Conversation

muyuuuu
Copy link
Contributor

@muyuuuu muyuuuu commented Dec 2, 2024

  1. 1d Thread Tile 优化的时候,我看图中的意思是一次读 4 行 A,一列 B,写 4 个 C
  2. 2d Thread Tile 补充了这个例子,多读几行 A 扩充到多读几列 B
  3. 代码中引入未知内容:__launch_bounds__((BM * BN) / (TM * TN), 1),好像没有相关的解释。才疏学浅,我就不乱讲了,大佬可否补充一下?

@@ -88,7 +88,7 @@ nvcc -o matmul_shared matmul_shared.cu

![picture 0](images/3b9ca1d09a35e62b14f73b56e21b988d379bf0b38b8af6d4d9b17d9f46663c1c.png)

上图中,A 和 B 都是 7x7 的矩阵。当每一个线程只计算一个结果的时候,我们需要从 A 中读取 7 个数据,从 B 中读取 7 个数据,从 C 中读取 1 个数据,然后写一次 C。这样的话,每个线程需要读取 15 个数据,写一次数据。如果我们每一个线程计算 4 个结果,那么我们需要从 A 中读取 14 个数据,从 B 中读取 14 个数据,从 C 中读取 4 个数据,然后写 4 次 C。这样的话,每个线程需要读取 32 个数据,写 4 次数据。计算每个线程的平方结果比计算结果的列更有效,因为这样我们可以共享更多的输入。
上图中,A 和 B 都是 7x7 的矩阵。当每一个线程只计算一个结果的时候,我们需要从 A 中读取 7 个数据,从 B 中读取 7 个数据,从 C 中读取 1 个数据,然后写一次 C。这样的话,每个线程需要读取 15 个数据,写一次数据。如果我们每一个线程计算 4 个结果,那么我们需要从 A 中读取 28 个数据,从 B 中读取 7 个数据,从 C 中读取 4 个数据,然后写 4 次 C。这样的话,每个线程需要读取 32 个数据,写 4 次数据。计算每个线程的平方结果比计算结果的列更有效,因为这样我们可以共享更多的输入。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里 A 就是读取 14 个数据,往 C 里面写 4 个数据,就需要读 A 的俩行,B 的俩行,每行 7 个数据

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看错图了,看到下面那个图了...

@@ -6,15 +6,21 @@

在介绍二维 Thread Tile 之前,我们先来回顾一下一维 Thread Tile 的优化方法。在初级系列中,我们使用了一维线程块来优化矩阵乘法的性能,我们将矩阵乘法的计算任务分配给了一维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。

我们在每个线程中计算了一维的矩阵块。想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。
还记得一维 Thread Tile 中的例子吗?如果输入的 A 和 B 都是 8x8 的矩阵:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

和之前一样保持 7x7 的矩阵是不是比较好?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经修改,可能一开始考虑到 7x7 不会被整除

@AndSonder
Copy link
Collaborator

AndSonder commented Dec 3, 2024

__launch_bounds__((BM * BN) / (TM * TN), 1) 的问题:

如果内核函数对寄存器或共享内存的使用量较高,使用 __launch_bounds__ 可以帮助编译器更好地管理资源,确保在每个多处理器上驻留足够数量的线程块

这个后续感觉可以在新手村系列加一篇文章单独介绍一下,做一点简单的小实验对比下它的效果,你要是有时间欢迎提 PR 😄

@muyuuuu
Copy link
Contributor Author

muyuuuu commented Dec 3, 2024

使用 __launch_bounds__ 可以帮助编译器更好地管理资源

查了下文档:

 __launch_bounds__(max_threads_per_block, min_blocks_per_sm)

有个疑问,运行 kernel 的时候会指定 block size,也就是每个 block 的线程数是确定的,为啥还有 max_threads_per_block 呢?

@AndSonder
Copy link
Collaborator

使用 __launch_bounds__ 可以帮助编译器更好地管理资源

查了下文档:

 __launch_bounds__(max_threads_per_block, min_blocks_per_sm)

有个疑问,运行 kernel 的时候会指定 block size,也就是每个 block 的线程数是确定的,为啥还有 max_threads_per_block 呢?

__launch_bounds__(max_threads_per_block, min_blocks_per_sm) 是一种编译器指示

在编译阶段,编译器并不知道你在运行时会用什么线程数。因此,它需要依赖类似 __launch_bounds__ 的指示,假设可能的线程数上限,以便对内核的资源使用进行优化。

如果编译器认为你可能会用到 1024 个线程,而实际运行时只用到 256 个线程,寄存器可能被过度保守地分配,降低了效率。

max_threads_per_block 还提供了一层安全检查,如果运行时的 threads 超过 max_threads_per_block,内核调用会失败。

两者的目标不同:<<<blocks, threads>>> 直接影响运行时行为,而 max_threads_per_block 是为了优化编译期资源分配

Copy link
Collaborator

@AndSonder AndSonder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@AndSonder AndSonder merged commit a6b8d59 into PaddleJitLab:develop Dec 3, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants